package spark.pipeline


import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}


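// `import spark.implicits._` (formerly `import sqlContext.implicits._`) must appear after the statement that obtains the SparkSession / SQLContext instance.
// A case class used with toDF() (`case class Name(field: Type, ...)`) must be defined outside the method scope, i.e. at member level (where a Java field would go).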
/**
 * Trains and evaluates a logistic-regression classifier on the Iris data set with a
 * Spark ML Pipeline.
 *
 * The input is a text file with one sample per line: four comma-separated numeric
 * features followed by the species name. The path may be passed as the first command-line
 * argument; when omitted, [[IrisLogisticRegression.DefaultInputPath]] is used.
 *
 * The pipeline is: label indexing -> feature indexing -> logistic regression ->
 * index-to-label conversion. The classifier is fitted on a random 80% split only and the
 * accuracy on the remaining 20% is printed.
 */
object IrisLogisticRegression {
  /**
   * One parsed input row. The field names become the DataFrame column names
   * ("features", "label"). The lower-case class name is kept for compatibility.
   *
   * @param features the four numeric measurements as a dense vector
   * @param label    the species name exactly as it appears in the file
   */
  case class iris(features : Vector, label : String)

  /** Input file used when no path is supplied on the command line. */
  private val DefaultInputPath = "C:/Users/Lenovo/Desktop/Working/Python/data/iris.txt"

  /**
   * Program entry point.
   *
   * @param args optional; `args(0)` is the path of the Iris text file
   */
  def main(args:Array[String]): Unit ={
    val spark = SparkSession.builder().appName("iris").master("local").getOrCreate()
    try {
      val file = args.headOption.getOrElse(DefaultInputPath)
      val data = spark.sparkContext.textFile(file)
      // Must come after the SparkSession has been created; provides toDF().
      import spark.implicits._

      val irisData = data
        .filter(_.trim.nonEmpty) // a blank line would otherwise fail in toDouble
        .map(_.split(","))
        .map(p => iris(Vectors.dense(p(0).toDouble, p(1).toDouble, p(2).toDouble, p(3).toDouble), p(4)))
        .toDF()
        .cache()

      // Map species names to numeric indices. Fitted on the full data set so that
      // every label is known to the indexer regardless of how the split falls.
      val labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("indexedLabel")
        .fit(irisData)
      val labelIris = labelIndexer.transform(irisData)
      labelIris.show(false)

      // Index the feature vector column.
      val featureIndexer = new VectorIndexer()
        .setInputCol("features")
        .setOutputCol("indexedFeatures")
        .fit(labelIris)

      val featuresIris = featureIndexer.transform(labelIris)
      featuresIris.show(false)

      // Logistic regression: at most 100 iterations, regularization 0.3, elastic-net
      // mixing 0.8 (see lr.explainParams() for all parameters).
      // Deliberately left UNFITTED: it is an estimator stage that the pipeline trains on
      // the training split only. Fitting it here on the full data would leak the test
      // rows into the model and turn Pipeline.fit into a no-op.
      val lr = new LogisticRegression()
        .setLabelCol("indexedLabel")
        .setFeaturesCol("indexedFeatures")
        .setMaxIter(100)
        .setRegParam(0.3)
        .setElasticNetParam(0.8)

      // Convert the predicted index back to the species name.
      val labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels)

      // Each stage's output columns feed the next stage.
      val lrPipeline = new Pipeline()
        .setStages(Array(labelIndexer, featureIndexer, lr, labelConverter))

      // 80% training / 20% test. Pipeline.fit produces a PipelineModel (a transformer)
      // whose transform() appends the prediction columns to the test set.
      val Array(trainingData, testData) = irisData.randomSplit(Array(0.8, 0.2))
      val lrPipelineModel = lrPipeline.fit(trainingData)
      val lrPredictions = lrPipelineModel.transform(testData)

      lrPredictions.show(false)

      // The evaluator's default metric is F1; request accuracy explicitly so the number
      // matches the message printed below.
      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")
      val lrAccuracy = evaluator.evaluate(lrPredictions)
      // The split is unseeded, so the value varies from run to run.
      println("预测的正确率为:"+lrAccuracy)
    } finally {
      // Release the local Spark context even when the job fails.
      spark.stop()
    }
  }
}
